
Discrete diffusion in diffusers #12911 (Draft)

kashif wants to merge 61 commits into huggingface:main from kashif:diff-d2

Conversation

@kashif (Contributor) commented Jan 4, 2026

What does this PR do?

Adds experimental support for discrete token diffusion methods and pipelines.

LLaDA 2 support has been moved to its own PR: #13226

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif marked this pull request as draft on January 4, 2026 at 23:34
@kashif changed the title from "Discrete diffusion in diffuers" to "Discrete diffusion in diffusers" on Jan 4, 2026
@yiyixuxu (Collaborator)

Thanks for this PR!
cc @dg845, can you take a look here? It's related to Dream 7B (#12091), which you are working on.

@dg845 (Collaborator) commented Jan 23, 2026

Thanks for the PR! Some preliminary design questions and comments:

  1. I think it could be useful to have a natural place to implement logic which is common to discrete diffusion models. Would something like a DiscreteDiffusionPipelineMixin make sense? For example, I think _resolve_start_token_id, _normalize_prefix_ids, _top_p_filtering, etc. could be candidates as mixin methods. (A possible alternative could be to put the methods in DiffusionPipeline, but it feels a little weird to put the methods there because they aren't applicable to continuous diffusion models.) But maybe this is premature, since we might not know what logic will end up being useful for all (or most) discrete diffusion models.
    1. One motivation for this is that we often want to do semi-autoregressive (SAR) sampling for discrete diffusion models, so it would be useful to have autoregressive sampling techniques such as top-$p$ sampling, top-$k$ sampling, etc. So I think it would be nice to have a place where these methods can be implemented and tested once, and then new discrete diffusion models that support SAR sampling can have easy access to them without having to copy them every time.
  2. Similarly, would it make sense to have a TokenizerTextProcessor class which handles text pre-processing and post-processing, analogous to how VaeImageProcessor handles image pre- and post-processing? It's probably less necessary since we don't need to do as much normalization as for images, but I could see this being useful for handling e.g. chat templates like in the SDAR and LLaDA 2 pipelines.
    1. As an aside, this could also be useful for existing (continuous) diffusion models, some of which have pretty involved text processing, such as pipelines like SanaPipeline that use a _text_preprocessing method:
      # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
      def _text_preprocessing(self, text, clean_caption=False):
  3. Currently it looks like the pipelines only support denoising models with a transformers-like interface. But we would probably want to implement some discrete diffusion transformers in diffusers, which currently doesn't enforce that interface. So I think we should think about how we can handle both cases gracefully in discrete diffusion pipelines. (One solution could be to simply adopt the transformers interface for all discrete denoising models in diffusers, but that could be unnecessarily restrictive.)
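To make point 1.1 concrete, a top-$p$ (nucleus) filtering helper of the kind such a mixin could host might look like this (a minimal sketch; the function name and signature are illustrative, not the PR's actual implementation):

```python
import torch


def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = float("-inf")) -> torch.Tensor:
    """Mask out logits outside the smallest set whose cumulative probability exceeds `top_p`."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Remove tokens past the nucleus; shift right so the first token crossing top_p is kept.
    sorted_remove = cumulative_probs > top_p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, filter_value)
```

A sibling top-$k$ helper would follow the same sort-then-`masked_fill` pattern, which is why implementing and testing them once in a shared mixin is attractive.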

Comment on lines +73 to +77
self.register_to_config(
    seq_len=seq_len,
    num_inference_steps=num_inference_steps,
    inject_start_token=inject_start_token,
)

Generally we don't register default __call__ arguments to the config, but rather set them as default arguments to the __call__ method:

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 512,
    width: int = 768,
    num_frames: int = 121,
    frame_rate: float = 24.0,
    num_inference_steps: int = 40,

Comment on lines +148 to +149
*,
batch_size: int = 1,

diffusers pipelines usually don't set __call__ arguments to be keyword-only. (That's not to say that there are no arguments for it, but because other pipelines allow positional arguments I think the expectation is that discrete diffusion pipelines will allow them as well.)

Comment on lines +185 to +190
if seq_len is None:
    seq_len = int(self.config.seq_len)
if num_inference_steps is None:
    num_inference_steps = int(self.config.num_inference_steps)
if inject_start_token is None:
    inject_start_token = bool(self.config.inject_start_token)

Following up on #12911 (comment), this logic could be removed if we don't register default arguments to the config.

Comment on lines +217 to +221
if infill_mask is not None:
    if infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )

I think input checking and exceptions should be moved to a check_inputs method, which is the usual practice for diffusers pipelines:

def check_inputs(
    self,
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):

return int(token_id)
return None

def _init_latents(

We usually name methods which sample latents from the prior distribution prepare_latents:

Comment on lines +102 to +118
if hasattr(self.scheduler, "forward_process") and getattr(self.scheduler, "forward_process") == "uniform":
    # Uniform prior over token IDs. Mirror scheduler's exclude-mask behavior.
    if getattr(self.scheduler, "exclude_mask_from_uniform", False) and hasattr(
        self.scheduler, "_sample_uniform_tokens"
    ):
        return self.scheduler._sample_uniform_tokens(
            torch.Size((batch_size, seq_len)),
            device=device,
            dtype=torch.long,
            generator=generator,
        )
    vocab_size = int(getattr(self.scheduler, "vocab_size", 0))
    if vocab_size <= 0:
        raise ValueError("Scheduler must define `vocab_size` for uniform prior sampling.")
    return torch.randint(
        0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long, generator=generator
    )

Suggestion: maybe it would be cleaner to define a scheduler method called (say) sample_prior which samples from the prior distribution based on the configured forward_process? So if self.forward_process == "uniform", we would call _sample_uniform_tokens under the hood in sample_prior to sample from a uniform prior distribution.

I think this would allow for more graceful support of other possible forward processes, and make the pipeline code cleaner (as most of the logic would be handled inside the scheduler).
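Sketched out, the suggested dispatch might look something like this (a toy illustration; the class and attribute names are assumptions, not the PR's code):

```python
import torch


class ToyTokenScheduler:
    """Minimal stand-in showing a `sample_prior` method keyed on `forward_process`."""

    def __init__(self, forward_process: str = "absorbing", mask_token_id: int = 0, vocab_size: int = 100):
        self.forward_process = forward_process
        self.mask_token_id = mask_token_id
        self.vocab_size = vocab_size

    def sample_prior(self, shape, device=None, generator=None) -> torch.Tensor:
        if self.forward_process == "absorbing":
            # Masked/absorbing diffusion: the prior is all mask tokens.
            return torch.full(shape, self.mask_token_id, dtype=torch.long, device=device)
        if self.forward_process == "uniform":
            # Uniform prior over the vocabulary (e.g. delegating to `_sample_uniform_tokens` internally).
            return torch.randint(0, self.vocab_size, shape, dtype=torch.long, device=device, generator=generator)
        raise ValueError(f"Unsupported forward process: {self.forward_process!r}")
```

With this, the pipeline's latent preparation would reduce to a single `self.scheduler.sample_prior(...)` call.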


# 3. Prepare latents
input_ids = self.prepare_latents(batch_size, seq_len, generator=generator, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

It looks like attention_mask is currently not being used in this pipeline, is this expected?

texts: list[str] | None = None


class TokenDiffusionPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin):

Can you add a usage example for TokenDiffusionPipeline?

texts: list[str] | None = None


class HybridTokenDiffusionPipeline(TokenDiffusionPipeline):

Can we inherit from DiffusionPipeline instead of TokenDiffusionPipeline and copy over common methods as necessary?

texts: list[str] | None = None


class HybridTokenDiffusionPipeline(TokenDiffusionPipeline):

My understanding is that TokenDiffusionPipeline and HybridTokenDiffusionPipeline implement essentially the same logic (but are intended to be used with different schedulers). Would it be possible to consolidate these two pipelines into a single TokenDiffusionPipeline which works with both TokenDiffusionScheduler and HybridTokenDiffusionScheduler?

Comment on lines +342 to +344
cur_x = x[:, start:end].clone()
cur_position_ids = position_ids[:, start:end]
cur_attn_mask = attn_mask[start:end, :end].unsqueeze(0)

Suggested change:
-cur_x = x[:, start:end].clone()
-cur_position_ids = position_ids[:, start:end]
-cur_attn_mask = attn_mask[start:end, :end].unsqueeze(0)
+block_x = x[:, start:end].clone()
+block_position_ids = position_ids[:, start:end]
+block_attn_mask = attn_mask[start:end, :end].unsqueeze(0)

nit: rename block-level variables to use the prefix block (e.g. cur_x --> block_x) following the LLaDA 2 pipeline.

Comment on lines +295 to +296
num_blocks = (prompt_length + int(max_new_tokens) + int(block_length) - 1) // int(block_length)
total_length = int(num_blocks) * int(block_length)

Suggested change:
-num_blocks = (prompt_length + int(max_new_tokens) + int(block_length) - 1) // int(block_length)
-total_length = int(num_blocks) * int(block_length)
+num_blocks = (prompt_length + max_new_tokens + block_length - 1) // block_length
+total_length = num_blocks * block_length

Can we remove the int(...) casts here and elsewhere? It makes the code more readable and we are annotating max_new_tokens and block_length as ints so this should be safe.

Comment on lines +385 to +387
transfer_index = step_output.transfer_index
sampled_tokens = step_output.sampled_tokens
sampled_probs = step_output.sampled_probs

Suggested change:
-transfer_index = step_output.transfer_index
-sampled_tokens = step_output.sampled_tokens
-sampled_probs = step_output.sampled_probs

I think we can remove this as these variables are no longer being used.

# Get model predictions only when p_x0 cache is invalidated
if p_x0_cache is None:
    sigma_t = self.scheduler.compute_sigma(t, batch_size)
    model_input = x_accum[:, -block_length:]

Would it be possible to refactor the block logic here to be more parallel with other block discrete diffusion pipelines such as LLaDA2Pipeline?

current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
block_attn_mask = attn_mask[:, :current_window_end]
block_position_ids = position_ids[:, :current_window_end]

)

@classmethod
def from_pretrained(

Why do we need to override from_pretrained for the DFlash pipeline?

texts: list[str] | None = None


def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:

I think inlining this method would be more clear as we only call it once (in _get_target_layer_ids).

Comment on lines +320 to +321
self.draft_model.eval()
self.target_model.eval()

Suggested change:
-self.draft_model.eval()
-self.target_model.eval()

I think the draft_model and target_model should already be set to eval mode, so we don't need to explicitly call it here.

return sequences, texts
return DFlashPipelineOutput(sequences=sequences, texts=texts)

def _get_block_size(self) -> int:

Do we need to support general draft and target models in _get_block_size and other methods below? My impression is that the DFlash model uses unique modeling logic, so it seems unlikely that we could drop in a random draft or target model and have it work out of the box. So I think it's reasonable to only support existing DFlash checkpoints such as z-lab/Qwen3-8B-DFlash-b16.


The returned tensor is expected to be in (0, 1] and monotone decreasing in `t`.
"""
if self.alpha_schedule == "log_linear":

Can you give a reference which uses these $\alpha(t, \epsilon)$ schedules? The references I can find (e.g. Appendix E.1 of the MDLM paper, Table 4 of the MD4 paper) list these alpha schedules without an epsilon argument.

noised = torch.where(block_mask.to(device=device), noised, original_samples)
return noised

def enforce_fixed_masks(

I think it might be better to inline enforce_fixed_masks in the pipeline code (e.g. in TokenDiffusionPipeline), as it is relatively simple and the choice of whether and how to enforce prefix/infill conditioning seems more like a pipeline design choice.

Comment on lines +72 to +75
self.vocab_size = int(vocab_size)
self.mask_token_id = int(mask_token_id)
self.num_train_timesteps = int(num_train_timesteps)
self.t_eps = float(t_eps)

Suggested change:
-self.vocab_size = int(vocab_size)
-self.mask_token_id = int(mask_token_id)
-self.num_train_timesteps = int(num_train_timesteps)
-self.t_eps = float(t_eps)

I think this is unnecessary as register_to_config should make these available as self.config.vocab_size, self.config.mask_token_id, etc.

Comment on lines +77 to +81
p_uniform = max(math.exp(-float(clip_noise)), float(p_uniform))
log_B = float(gamma) * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform)
log_B = float(np.clip(log_B, -float(clip_noise), float(clip_noise)))
self.log_B = float(log_B)
self.log_gamma = float(math.log(float(gamma)))

Suggested change:
-p_uniform = max(math.exp(-float(clip_noise)), float(p_uniform))
-log_B = float(gamma) * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform)
-log_B = float(np.clip(log_B, -float(clip_noise), float(clip_noise)))
-self.log_B = float(log_B)
-self.log_gamma = float(math.log(float(gamma)))
+p_uniform = max(math.exp(-clip_noise), p_uniform)
+log_B = gamma * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform)
+log_B = float(np.clip(log_B, -clip_noise, clip_noise))
+self.log_B = log_B
+self.log_gamma = math.log(gamma)

Can we remove all the int(...) and float(...) casts here and elsewhere?


class HybridTokenDiffusionScheduler(SchedulerMixin, ConfigMixin):
    """
    Hybrid-transition discrete token diffusion scheduler.

Can the __init__ arguments (such as p_uniform, clip_noise, gamma, etc.) be documented in the docstring here, including what they mean and what values might be reasonable for them?

Comment on lines +57 to +58
p_uniform: float = 0.0,
clip_noise: float = 20.0,

The default parameters here set the effective p_uniform to $\exp(-20)$. If I understand correctly, p_uniform corresponds to the maximum probability that tokens will transition to another token uniformly at random instead of to the mask token. Since $\exp(-20)$ is very small, I believe the default schedule is thus essentially identical to a normal absorbing/masked diffusion schedule. Is there a reasonable default setting such that this scheduler behaves noticeably different from standard absorbing diffusion (e.g. as implemented by TokenDiffusionScheduler)?
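A quick numeric check of this observation, plugging the defaults above into the clamping logic quoted earlier in this thread:

```python
import math

# Defaults from the scheduler signature: p_uniform=0.0, clip_noise=20.0.
clip_noise = 20.0
p_uniform = 0.0

# __init__ clamps p_uniform from below at exp(-clip_noise):
effective_p_uniform = max(math.exp(-clip_noise), p_uniform)
# exp(-20) is about 2.06e-9, so uniform transitions are vanishingly rare by default
# and the schedule is effectively pure absorbing/masked diffusion.
```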

Comment on lines +116 to +117
elif noise_type == "cosine":
    return 1.0 - (1.0 - eps) * torch.cos(t * math.pi / 2.0)

The alpha schedules here are different from the alpha schedules defined in TokenDiffusionScheduler, is this intended? See also #12911 (comment).


def _compute_move_chance(self, t: torch.Tensor) -> torch.Tensor:
    """
    Compute the probability that a token has been masked (move chance) at continuous time *t*.

If I understand correctly, the move chance is given by $1 - \alpha(t)$. I think documenting the relationship between the move_chance and alpha_t here would be useful since it would be easier to compare this scheduler with similar schedulers such as TokenDiffusionScheduler.
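In symbols, the relationship being requested is (a sketch following the comment, not the PR source):

```latex
\text{move\_chance}(t) = 1 - \alpha(t),
\qquad \alpha(0) \approx 1, \quad \alpha(1) \approx \varepsilon,
```

so that, for example, a linear schedule $\alpha(t) = 1 - (1 - \varepsilon)\,t$ yields $\text{move\_chance}(t) = (1 - \varepsilon)\,t$, increasing from $0$ at $t = 0$ to $1 - \varepsilon$ at $t = 1$.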

`torch.Tensor`: Move chance at each timestep value, same shape as *t*.
"""
noise_type = self.config.noise_type
eps = 1e-3

I think eps should be configurable via __init__, following TokenDiffusionScheduler.

# Compute move chances at t and s = t - dt
# ------------------------------------------------------------------
move_chance_t = self._compute_move_chance(t).to(dtype=torch.float64)
move_chance_s = self._compute_move_chance(t - dt).to(dtype=torch.float64)

I think getting s from the timestep schedule would be better in case we want to support non-linspace timestep schedules.
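The idea can be sketched as reading both endpoints from the schedule rather than assuming a fixed spacing (illustrative only; the variable names are assumptions):

```python
import torch

# A non-uniform (non-linspace) timestep schedule, noisiest first.
timesteps = torch.tensor([1.0, 0.7, 0.45, 0.25, 0.1, 0.0])

# Read s directly from the schedule instead of computing s = t - dt,
# which silently assumes uniform spacing between timesteps.
step_pairs = [(timesteps[i], timesteps[i + 1]) for i in range(len(timesteps) - 1)]
```

Each denoising update would then receive `(t, s)` from `step_pairs` and stay correct for any monotone schedule.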

# Subs parameterization: mask token gets -inf, then log_softmax normalizes.
# For unmasked positions, the distribution is forced to be the identity.
logits[..., mask_token_id] = -1e9
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

My understanding is that this computes a log softmax, would using something like torch.special.log_softmax be better here? Also I think it might be more clear to rename this to something like log_probs.

Comment on lines +263 to +265
gumbel_noise = -(torch.rand_like(q_xs, generator=generator) + 1e-10).log()
gumbel_noise = (1e-10 + gumbel_noise).clamp(min=1e-30)
x_block = (q_xs / gumbel_noise).argmax(dim=-1)

Could we reuse the _gumbel_argmax function defined in scheduling_token_diffusion.py here?
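For reference, a Gumbel-argmax helper along these lines might look as follows (a sketch; the actual `_gumbel_argmax` in `scheduling_token_diffusion.py` may differ in name and signature):

```python
import torch


def gumbel_argmax(probs: torch.Tensor, generator=None) -> torch.Tensor:
    """Categorical sampling via the Gumbel-max trick on unnormalized probabilities.

    argmax(p / Exp(1)) is distributed like argmax(log p + Gumbel(0, 1)),
    since the negative log of an Exp(1) variable is Gumbel(0, 1)-distributed.
    """
    exp_noise = -(torch.rand(probs.shape, generator=generator) + 1e-10).log()
    exp_noise = exp_noise.clamp(min=1e-30)
    return (probs / exp_noise).argmax(dim=-1)
```

Reusing one such helper in both schedulers would also keep the `1e-10`/`1e-30` numerical-stability constants in a single place.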

Comment on lines +76 to +79
def step(
    self,
    draft_tokens: torch.LongTensor,
    target_logits: torch.Tensor,

Would it be possible to refactor the DFlash step method such that it follows the standard step interface?

def step(
    self,
    model_output: torch.Tensor,
    timestep: int | torch.Tensor,
    sample: torch.LongTensor,

kashif added 5 commits April 2, 2026 18:40
Pipelines:
- TokenDiffusion: add usage example, remove unused attention_mask, add
  sample_prior to scheduler, inline enforce_fixed_masks
- HybridTokenDiffusion: consolidate into thin wrapper over TokenDiffusion
- SDAR: rename cur_x to block_x, remove int() casts, remove unused vars
- DFlash: inline _get_target_layer_ids, remove eval() calls, remove
  from_pretrained override, simplify model support
- BD3LM: refactor block logic parallel with LLaDA2

Schedulers:
- TokenDiffusion: pre-compute alpha schedule, cleaner if/elif in step,
  add sample_prior method
- HybridTokenDiffusion: remove redundant self.xxx assignments, document
  params, remove int/float casts
- BD3LM: make eps configurable, document move_chance vs alpha_t, use
  log_softmax, get s from timestep schedule, reuse _gumbel_argmax
- DFlash: refactor step to standard model_output interface